{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# default_exp problem_types.multi_cls\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "# Set before TensorFlow is imported in the next code cell; `-1` presumably hides every CUDA device so the notebook runs on CPU (TODO confirm).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# hide\n",
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "from m3tl.test_base import test_top_layer\n",
    "# Shared test fixture; `params` is reused by the top-layer test cell below.\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "# Passed as `hidden_dim` to test_top_layer further down.\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "# Pull a single batch from the train/eval input pipeline to use as sample features for the layer test.\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "Adding new problem weibo_cws, problem type: seq_tag\n",
      "Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "Adding new problem weibo_fake_cls, problem type: cls\n",
      "Adding new problem weibo_masklm, problem type: masklm\n",
      "Adding new problem weibo_pretrain, problem type: pretrain\n",
      "Adding new problem weibo_fake_regression, problem type: regression\n",
      "Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "Adding new problem weibo_premask_mlm, problem type: premask_mlm\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n",
      "/data/m3tl/m3tl/read_write_tfrecord.py:83: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  elif np.issubdtype(type(feature), np.float):\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "INFO:tensorflow:sampling weights: \n",
      "INFO:tensorflow:weibo_fake_cls_weibo_fake_ner: 0.07042253521126761\n",
      "INFO:tensorflow:weibo_fake_multi_cls: 0.07042253521126761\n",
      "INFO:tensorflow:weibo_masklm: 0.06690140845070422\n",
      "INFO:tensorflow:weibo_pretrain: 0.7922535211267606\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Multi-label classification(multi_cls)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Multi-label classification. By default this problem will use `[CLS]` token embedding.\n",
    "\n",
    "Example: `m3tl.predefined_problems.get_weibo_fake_multi_cls_fn`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Imports and utils\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# export\n",
    "import pickle\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling)\n",
    "from m3tl.special_tokens import PREDICT, TRAIN\n",
    "from m3tl.utils import (LabelEncoder, get_label_encoder_save_path, get_phase,\n",
    "                        need_make_label_encoder, variable_summaries)\n",
    "from sklearn.preprocessing import MultiLabelBinarizer\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Top Layer"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# export\n",
    "\n",
    "class MultiLabelClassification(tf.keras.Model):\n",
    "    \"\"\"Top layer for a multi-label classification problem.\n",
    "\n",
    "    Takes the `pooled` hidden feature, applies dropout and a dense\n",
    "    projection to `num_classes` logits, and returns their sigmoid.\n",
    "    In every phase except PREDICT it also registers a positively\n",
    "    weighted sigmoid cross-entropy loss through `add_loss`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params: BaseParams, problem_name: str) -> None:\n",
    "        \"\"\"Build the dropout and dense sub-layers for `problem_name`.\n",
    "\n",
    "        Args:\n",
    "            params: project params; read for the `num_classes` problem\n",
    "                info, `dropout_keep_prob`, `detail_log` and\n",
    "                `multi_cls_positive_weight`.\n",
    "            problem_name: problem this head serves; also used as the\n",
    "                Keras model name.\n",
    "        \"\"\"\n",
    "        super().__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "\n",
    "        num_classes = self.params.get_problem_info(\n",
    "            problem=problem_name, info_name='num_classes')\n",
    "        self.dense = tf.keras.layers.Dense(num_classes)\n",
    "\n",
    "        # params holds a keep probability, Dropout takes the drop rate\n",
    "        drop_rate = 1-self.params.dropout_keep_prob\n",
    "        self.dropout = tf.keras.layers.Dropout(drop_rate)\n",
    "        # NOTE: an F1 metric (tfa.metrics.F1Score) was sketched for this\n",
    "        # layer previously but is currently disabled.\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Compute per-class probabilities and, outside PREDICT, add the loss.\n",
    "\n",
    "        Args:\n",
    "            inputs: pair `(feature, hidden_feature)`; labels are read from\n",
    "                `feature['<problem_name>_label_ids']` and the layer input\n",
    "                from `hidden_feature['pooled']`.\n",
    "\n",
    "        Returns:\n",
    "            Sigmoid of the dense logits.\n",
    "        \"\"\"\n",
    "        phase = get_phase()\n",
    "        predicting = (phase == PREDICT)\n",
    "        feature, hidden_feature = inputs\n",
    "        pooled = hidden_feature['pooled']\n",
    "\n",
    "        # labels are only looked up outside of prediction\n",
    "        labels = None\n",
    "        if not predicting:\n",
    "            labels = feature['{}_label_ids'.format(self.problem_name)]\n",
    "\n",
    "        pooled = self.dropout(pooled, phase == TRAIN)\n",
    "        logits = self.dense(pooled)\n",
    "\n",
    "        if self.params.detail_log:\n",
    "            for weight_variable in self.weights:\n",
    "                variable_summaries(weight_variable, self.problem_name)\n",
    "\n",
    "        if not predicting:\n",
    "            float_labels = tf.cast(labels, tf.float32)\n",
    "            # positive examples are re-weighted by this factor in the loss\n",
    "            positive_weight = self.params.multi_cls_positive_weight\n",
    "\n",
    "            def _weighted_sigmoid_ce(x, y, from_logits=True):\n",
    "                # from_logits is accepted but unused by this wrapper\n",
    "                return tf.nn.weighted_cross_entropy_with_logits(x, y, pos_weight=positive_weight, name='{}_loss'.format(self.problem_name))\n",
    "\n",
    "            safe_loss = nan_loss_handling(\n",
    "                empty_tensor_handling_loss(float_labels, logits, _weighted_sigmoid_ce))\n",
    "            self.add_loss(safe_loss)\n",
    "\n",
    "        return tf.nn.sigmoid(\n",
    "            logits, name='%s_predict' % self.problem_name)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Run the shared top-layer test helper on this class, using the sample batch and hidden size prepared in the test setup cell.\n",
    "test_top_layer(MultiLabelClassification, problem='weibo_fake_multi_cls', params=params, sample_features=one_batch, hidden_dim=hidden_dim)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Testing MultiLabelClassification\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Get or make label encoder function\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# export\n",
    "def multi_cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> MultiLabelBinarizer:\n",
    "    \"\"\"Fit-and-save or load the `MultiLabelBinarizer` for `problem`.\n",
    "\n",
    "    When `need_make_label_encoder` reports that a new encoder is required,\n",
    "    a `MultiLabelBinarizer` is fitted on `label_list`, pickled to the label\n",
    "    encoder save path, and the number of fitted classes is stored as the\n",
    "    problem's `num_classes` info. Otherwise the pickled encoder is loaded.\n",
    "\n",
    "    Args:\n",
    "        params: project params, used to resolve the save path and to\n",
    "            record `num_classes`.\n",
    "        problem: problem name.\n",
    "        mode: current mode, forwarded to `need_make_label_encoder`.\n",
    "        label_list: labels to fit on (only used when fitting).\n",
    "        **kwargs: must contain `overwrite`, forwarded to\n",
    "            `need_make_label_encoder`.\n",
    "\n",
    "    Returns:\n",
    "        The fitted or loaded `MultiLabelBinarizer`.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `overwrite` is not passed as a keyword argument.\n",
    "    \"\"\"\n",
    "    le_path = get_label_encoder_save_path(params=params, problem=problem)\n",
    "\n",
    "    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):\n",
    "        # fit and save label encoder\n",
    "        label_encoder = MultiLabelBinarizer()\n",
    "        label_encoder.fit(label_list)\n",
    "        # context manager guarantees the pickle is flushed and the handle closed\n",
    "        with open(le_path, 'wb') as le_file:\n",
    "            pickle.dump(label_encoder, le_file)\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=label_encoder.classes_.shape[0])\n",
    "    else:\n",
    "        # NOTE(security): pickle.load can execute arbitrary code; only load\n",
    "        # encoder files that come from a trusted location.\n",
    "        with open(le_path, 'rb') as le_file:\n",
    "            label_encoder = pickle.load(le_file)\n",
    "\n",
    "    return label_encoder"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Label handling function"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# export\n",
    "def multi_cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "    \"\"\"Encode one multi-label target with `label_encoder`.\n",
    "\n",
    "    Args:\n",
    "        target: the labels of a single example.\n",
    "        label_encoder: object exposing `transform`; required despite the\n",
    "            `None` default.\n",
    "        tokenizer: unused.\n",
    "        decoding_length: unused.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(label_id, None)` where `label_id` is the encoded target\n",
    "        converted with `np.int32`.\n",
    "    \"\"\"\n",
    "    # transform works on a batch: wrap the single target, then take its only row\n",
    "    encoded_row = label_encoder.transform([target])[0]\n",
    "    return np.int32(encoded_row), None\n",
    "\n"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}